#!/usr/bin/env python
# Copyright 2021-2024 The PySCF Developers. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time

import pyscf
from pyscf import lib
from pyscf.geomopt.geometric_solver import optimize

from gpu4pyscf.dft import rks

# Limit the number of CPU threads used by PySCF's host-side library routines.
# NOTE(review): 16 is hard-coded for the machine this script was written for;
# adjust to the host's core count.
lib.num_threads(16)

# Enzalutamide
# https://www.echemi.com/cms/408806.html
# Enzalutamide is an androgen receptor (AR) inhibitor,
# which was launched in 2012 and is used to treat castration-resistant prostate cancer and metastatic hormone-sensitive prostate cancer.
# Coordinates are in Angstrom and were generated with RDKit.
# Molecular geometry passed to pyscf.M: one atom per line, given as the element
# symbol followed by x, y, z. 48 atoms in total (C21H16F4N4O2S).
atom = '''
C       -0.52489785       0.25103945      -2.93015020
C       -0.50680345      -0.64393976      -1.68199231
C       -1.91157268      -0.53966994      -1.07689036
O       -2.89596008      -1.08758214      -1.56376930
N       -1.87455596       0.28604347       0.02890000
C       -0.55589417       0.64601028       0.30659520
S       -0.03771290       1.85199711       1.35883674
N        0.28457771      -0.08561063      -0.58428830
C        1.69095398      -0.30104521      -0.44016185
C        2.53850590      -0.42689013      -1.54537828
C        3.90805921      -0.65833698      -1.38306826
C        4.46679033      -0.77942388      -0.11616884
C        3.63537497      -0.68825726       1.00120202
C        2.26303281      -0.45702036       0.83390378
C        5.91434343      -1.04523359       0.03306867
O        6.48698930      -1.82113253      -0.71992937
N        6.53927381      -0.36651782       1.06132882
C        7.96812247      -0.45662800       1.23170297
F        4.66727037      -0.72425525      -2.48600744
C       -3.05365343       0.62887069       0.75835844
C       -4.27570462       0.82269353       0.09079407
C       -5.46414648       1.14892461       0.78005884
C       -5.42793076       1.27078784       2.17688071
C       -4.22194909       1.06521787       2.85487200
C       -3.05316878       0.74497593       2.15329820
C       -6.60149929       1.59446376       2.93995702
N       -7.54355554       1.85428697       3.56412460
C       -6.74928175       1.36174974       0.01633060
F       -7.27014744       2.60177958       0.21630004
F       -6.60372258       1.23850678      -1.33208742
F       -7.71631245       0.47409741       0.36837460
C       -0.22537027      -2.11184985      -2.01127572
H       -0.73323150       1.29620611      -2.67177714
H        0.42028392       0.23270869      -3.47581552
H       -1.30449790      -0.07030990      -3.63124037
H        2.18898530      -0.32226984      -2.56341479
H        4.03255894      -0.83066806       2.00353550
H        1.64713582      -0.45176503       1.73036295
H        6.06116959       0.41424928       1.49170789
H        8.45082885       0.17225656       0.47895774
H        8.22197650      -0.09849639       2.23196987
H        8.30061119      -1.49102468       1.10983334
H       -4.31936045       0.74694915      -0.99665531
H       -4.18054201       1.14208036       3.94151726
H       -2.15934732       0.54471490       2.73896838
H       -0.25324848      -2.73066872      -1.10671674
H       -0.97426928      -2.51384795      -2.70351794
H        0.75149209      -2.25342097      -2.48143483
'''

# Calculation settings.
xc = 'HYB_MGGA_XC_WB97M_V'    # wB97M-V functional, given by its libxc-style name
bas = 'def2-tzvpp'            # orbital basis set
auxbasis = 'def2-tzvpp-jkfit'  # auxiliary basis used for density fitting below
scf_tol = 1e-10               # SCF convergence threshold (mf.conv_tol)
max_scf_cycles = 200          # maximum number of SCF iterations (mf.max_cycle)
screen_tol = 1e-14            # NOTE(review): presumably an integral screening cutoff in gpu4pyscf
grids_level = 3               # DFT integration grid level

# max_memory is in MB (PySCF convention); 120000 assumes a large-memory host.
# verbose is passed to pyscf.M so that it already applies while the molecule is
# built, instead of being changed afterwards. The SCF object below inherits it
# from mol.
mol = pyscf.M(atom=atom, basis=bas, max_memory=120000, verbose=1)

# Density-fitted restricted Kohn-Sham on the GPU, without a dispersion correction.
mf_GPU = rks.RKS(mol, xc=xc, disp=None).density_fit(auxbasis=auxbasis)
mf_GPU.grids.level = grids_level
mf_GPU.conv_tol = scf_tol
mf_GPU.max_cycle = max_scf_cycles
mf_GPU.screen_tol = screen_tol

# History of the gradients reported by the optimizer, in the order received.
gradients = []


def callback(envs):
    """Record the gradient handed to the optimizer callback for one step.

    Parameters
    ----------
    envs : dict
        Mapping supplied by the optimizer's callback hook (presumably a
        snapshot of the solver's local variables -- confirm). Only the
        ``'gradients'`` entry is read; it is stored by reference, not copied.
    """
    step_gradient = envs['gradients']
    gradients.append(step_gradient)


# Run the constrained optimization / scan. The constraints file is handed to
# geomeTRIC via PySCF's geometric_solver; atom indices in it are 1-based.
# perf_counter is a monotonic clock, so the elapsed time is not affected by
# system clock adjustments (unlike time.time()).
start_time = time.perf_counter()
mol_eq = optimize(
    mf_GPU,
    maxsteps=500000000,  # effectively unlimited optimizer steps
    constraints='geometric_scan.txt',  # atom index is 1-based in this file
    callback=callback)
# Request Bohr explicitly so the printed label is guaranteed to match the values
# (the input geometry above is in Angstrom).
print("Optimized coordinate (Bohr):")
print(mol_eq.atom_coords(unit='Bohr'))
print(f"Elapsed time (s): {time.perf_counter() - start_time:.2f}")
